library(sf)
library(haven)
library(dplyr)
library(ggplot2)

#### Data wrangling ####

# Country polygons: layer "WB_countries_Admin0_10m" from the "Shape" directory
wbworld <- read_sf("Shape", layer = "WB_countries_Admin0_10m")

# IMF tax-advice observations, stored as a Stata .dta file
imf <- read_dta(file = "IVTAX_Dataset.dta")

# Add row-level 0/1 indicators to the IMF data:
#   tax_measure_ind - a tax measure is recorded and its tax_type is not "[None]"
#   tax_prog_dum    - the measure is classified as "Progressive"
#   tax_reg_dum     - the measure is classified as "Regressive"
# An indicator is NA (not 0) when the tax_type / tax_prog value it tests is NA.
imf <- imf %>%
  mutate(
    tax_measure_ind = ifelse(!is.na(tax_measure) & tax_type != "[None]", 1, 0),
    tax_prog_dum    = ifelse(tax_prog == "Progressive", 1, 0),
    tax_reg_dum     = ifelse(tax_prog == "Regressive", 1, 0)
  )

# Country-level summary of IMF tax advice.
# Step 1: per country-year, count all / progressive / regressive measures.
# Step 2: per country, average those yearly counts (columns *_mean).
# Step 3: tax_diff = mean progressive count - mean regressive count.
# Rows with a missing income classification are excluded. Each country keeps
# the income group of its first year in sorted order (first() within groups).
tax <- imf %>%
  filter(!is.na(income)) %>%
  mutate(income_group = ifelse(income == "High income", "High income", "Low/Middle income")) %>%
  group_by(ccode, year) %>%
  summarize(
    across(
      c(tax_measure_ind, tax_prog_dum, tax_reg_dum),
      function(x) sum(x, na.rm = TRUE)
    ),
    income_group = first(income_group),
    .groups = "drop_last"  # leaves the result grouped by ccode for step 2
  ) %>%
  summarize(
    across(
      c(tax_measure_ind, tax_prog_dum, tax_reg_dum),
      function(x) mean(x, na.rm = TRUE),
      .names = "{.col}_mean"
    ),
    income_group = first(income_group),
    .groups = "drop"
  ) %>%
  mutate(tax_diff = tax_prog_dum_mean - tax_reg_dum_mean)

## Merge stata data into shape data
# The shapefile uses a few legacy World Bank codes. Translate them to the codes
# used in the Stata data (Andorra, DR Congo, Kosovo) so the join below matches.
# All other codes, including NA, pass through unchanged.
wbworld <- wbworld %>%
  mutate(WB_A3 = case_when(
    WB_A3 == "ADO" ~ "AND",
    WB_A3 == "ZAR" ~ "COD",
    WB_A3 == "KSV" ~ "XKX",
    TRUE ~ WB_A3
  ))

# Join diagnostics: list the country codes that will not match across the two
# sources. Geometry is dropped and duplicates removed, so each unmatched
# country appears exactly once. (imf has many rows per country, one per
# year/measure, so without distinct() every observation would be listed.)

# Shapefile countries with no observation in the Stata data
mismatch_wbworld <- wbworld %>%
  st_drop_geometry() %>%
  anti_join(imf, by = c("WB_A3" = "ccode")) %>%
  distinct(WB_A3, WB_NAME)

# Stata countries with no polygon in the shapefile
mismatch_cond <- imf %>%
  anti_join(st_drop_geometry(wbworld), by = c("ccode" = "WB_A3")) %>%
  distinct(ccode, cname)

# Attach the country-level tax summaries to the map polygons, matching the
# shapefile's country code (WB_A3) to the Stata country code (ccode).
# Every polygon is kept; countries without tax data get NA in the tax columns.
world2 <- wbworld %>%
  left_join(tax, by = c("WB_A3" = "ccode")) %>%
  select(
    starts_with("tax"),
    WB_A3, WB_NAME,
    Shape_Leng, Shape_Area,
    geometry
  )



#### Factual statements ####

## Two countries stand out for their high tax progressivity scores: United States and Brazil. 
## On the regressive side, India and Belize received the most regressive recommendations.

# Rank countries from most progressive to most regressive score (ties keep
# their existing order; missing scores go last).
tax_sorted <- tax %>%
  arrange(desc(tax_diff))

# Show every country's code and score
tax_sorted %>%
  select(ccode, tax_diff) %>%
  print(n = Inf)

## high-income countries tend to have higher progressivity scores (mean = 0.1, n = 60, SD = 1.5),
## whereas low- and middle-income countries tend to score lower (mean = –0.4, n = 65, SD = 1.59). 

# Mean, median and standard deviation of the progressivity score, plus the
# number of countries, for each income group.
summary_stats <- tax %>%
  group_by(income_group) %>%
  summarize(
    across(
      tax_diff,
      list(mean = mean, median = median, sd = sd),
      .names = "{.fn}_{.col}"
    ),
    n = n()
  )
summary_stats %>% print()

## 39 countries (31%) received predominantly progressive advice—24 of them high-income, and 15 low- or middle-income. 
## 59 countries (47%) received mostly regressive advice, including 23 high-income and 36 low- or middle-income countries. 
## 27 countries (22%) received a roughly balanced mix, with 13 high-income and 14 low- or middle-income.

# Classify each country by the sign of its progressivity score. Then count and
# compute the within-group percentage of each category per income group.
# tax_diff is a difference of two floating-point means, so "Neutral" is tested
# with a tolerance (near) rather than exact == 0. Otherwise a one-ulp rounding
# residue would flip a balanced country to Progressive/Regressive.
tax %>%
  filter(!is.na(tax_diff), !is.na(income_group)) %>%
  mutate(category = case_when(
    near(tax_diff, 0) ~ "Neutral",
    tax_diff > 0      ~ "Progressive",
    tax_diff < 0      ~ "Regressive"
  )) %>%
  count(income_group, category, name = "count") %>%
  group_by(income_group) %>%
  mutate(percentage = round(100 * count / sum(count), 1)) %>%
  ungroup() %>%
  print()



#### Heat maps ####

## Figure 1. Number of tax recommendations per country 

# Bin the per-country average number of IMF tax measures into five classes.
# Intervals are closed on the left: [0,1), [1,5), [5,9), [9,13), [13,17].
# Each class gets a colour on a white-to-dark-red ramp.
bins <- c(0, 1, 5, 9, 13, 17)
labels <- c("0 to <1", "1 to <5", "5 to <9", "9 to <13", "13 to <17")
colors <- colorRampPalette(c("#ffffff", "#fca5a5", "#fb7185", "#e11d48", "#7f1d1d"))(length(labels))

# Create a binned factor variable for plotting
world2_binned_tax <- world2 %>%
  mutate(tax_measure_bin = cut(
    tax_measure_ind_mean,
    breaks = bins,
    labels = labels,
    include.lowest = TRUE,
    right = FALSE
  ))

# A country with data whose value falls outside `bins` gets an NA bin and would
# be drawn as "No data". Flag this so the bins/labels can be widened.
if (any(!is.na(world2_binned_tax$tax_measure_ind_mean) &
        is.na(world2_binned_tax$tax_measure_bin))) {
  warning("Some tax_measure_ind_mean values fall outside `bins` and are drawn as 'No data'.",
          call. = FALSE)
}

# Map
p_taxmeasures <- ggplot(world2_binned_tax) +
  geom_sf(aes(fill = tax_measure_bin), color = "black", linewidth = 0.3) +
  scale_fill_manual(
    # Colours are keyed by bin name, and every bin is kept (drop = FALSE).
    # The colour/label mapping therefore stays correct even if a bin is empty.
    values = setNames(colors, labels),
    drop = FALSE,
    name = "Number of IMF tax measures (average per Article IV)",
    na.value = "lightgrey",
    # Missing values (no IMF data) get the "No data" legend entry
    labels = function(x) ifelse(is.na(x), "No data", x)
  ) +
  guides(fill = guide_legend(
    title.position = "top",
    title.hjust = 0.5,
    nrow = 1
  )) +
  theme_minimal() +
  theme(
    legend.position = "bottom",
    legend.box.background = element_rect(color = "black", fill = "white"),
    legend.box.margin = margin(5, 5, 5, 5),
    legend.title = element_text(size = 12, face = "bold"),
    legend.text = element_text(size = 10),
    panel.grid = element_blank(),
    axis.text = element_blank(),
    axis.ticks = element_blank(),
    axis.title = element_blank()
  )

# Save the figure (create the output directory if it does not exist yet)
dir.create("Figures", showWarnings = FALSE)
ggsave("Figures/IMFTAX_Figure1.png", p_taxmeasures, width = 12, height = 8, dpi = 300, bg = "white")


## Figure 3. Tax progressivity score per country

# Ordered legend categories for the progressivity score (tax_diff)
categories <- c(
  "Highly progressive (4 to <7)",
  "Progressive (>0 to <4)",
  "Neutral (0)",
  "Regressive (<0 to >-4)",
  "Highly regressive (≤-4 to >-7)"
)

# One colour per category: greens for progressive, pinks/reds for regressive
fill_colors <- setNames(
  c("#33a02c", "#b2df8a", "white", "#f7c6d4", "#cc3366"),
  categories
)

# Categorise tax_diff and convert it to a factor with the legend order.
# - "Neutral" uses a floating-point tolerance (near) instead of exact == 0,
#   because tax_diff is a difference of two means.
# - The outer classes are open-ended, so a country that has data is never drawn
#   as "No data" just because its score lies beyond ±7.
world2_binned <- world2 %>%
  mutate(tax_diff_category = factor(case_when(
    is.na(tax_diff)   ~ NA_character_,
    near(tax_diff, 0) ~ categories[3],
    tax_diff >= 4     ~ categories[1],
    tax_diff > 0      ~ categories[2],
    tax_diff <= -4    ~ categories[5],
    tax_diff < 0      ~ categories[4]
  ), levels = categories))

# The legend text assumes scores lie strictly between -7 and 7
if (any(abs(world2_binned$tax_diff) >= 7, na.rm = TRUE)) {
  warning("Some tax_diff values lie outside the (-7, 7) range stated in the legend labels.",
          call. = FALSE)
}

# Plot
p <- ggplot(world2_binned) +
  geom_sf(aes(fill = tax_diff_category), color = "black", linewidth = 0.3) +
  scale_fill_manual(
    values = fill_colors,
    # Keep every category in the legend, even if no country falls in it.
    # Missing values (no IMF data) are labelled "No data".
    drop = FALSE,
    name = "Tax progressivity score (average per Article IV)",
    na.value = "lightgrey",
    labels = function(x) ifelse(is.na(x), "No data", x)
  ) +
  guides(fill = guide_legend(
    title.position = "top", title.hjust = 0.5, nrow = 2, byrow = TRUE
  )) +
  theme_minimal() +
  theme(
    legend.position = "bottom",
    legend.box.background = element_rect(color = "black", fill = "white"),
    legend.box.margin = margin(5, 5, 5, 5),
    legend.title = element_text(size = 12, face = "bold"),
    legend.text = element_text(size = 10),
    panel.grid = element_blank(),
    axis.text = element_blank(),
    axis.ticks = element_blank(),
    axis.title = element_blank()
  )

# Save the figure (create the output directory if it does not exist yet)
dir.create("Figures", showWarnings = FALSE)
ggsave("Figures/IMFTAX_Figure3.png", p, width = 12, height = 8, dpi = 300, bg = "white")
